import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical

from spinup.exploration.nets_exploration import (
    OnPolicyCoherentLinear,
    NoisyLinear,
    OurNoisyLinear,
    PSNELinear,
)


def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` to form a buffer shape tuple.

    ``shape`` may be ``None`` (result ``(length,)``), a scalar
    (result ``(length, shape)``) or an iterable of dimensions
    (result ``(length, *shape)``).
    """
    if shape is None:
        trailing = ()
    elif np.isscalar(shape):
        trailing = (shape,)
    else:
        trailing = tuple(shape)
    return (length,) + trailing


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a plain fully-connected network as an ``nn.Sequential``.

    For consecutive widths ``sizes[i] -> sizes[i + 1]`` an ``nn.Linear`` is
    added, followed by ``activation()`` for every layer except the last one,
    which is followed by ``output_activation()`` instead.
    """
    widths = list(zip(sizes[:-1], sizes[1:]))
    final_index = len(widths) - 1
    modules = []
    for index, (fan_in, fan_out) in enumerate(widths):
        act_cls = output_activation if index == final_index else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)


def coherent_mlp(
    sizes,
    activation,
    beta=0.01,
    std_w_init=0.017,
    std_a=0.1,
    output_activation=nn.Identity,
):
    """Build an MLP whose final linear layer is an ``OnPolicyCoherentLinear``.

    Hidden layers are ordinary ``nn.Linear`` modules, each followed by
    ``activation()``. The last layer is
    ``OnPolicyCoherentLinear(sizes[-2], sizes[-1], beta, std_w_init, std_a)``
    followed by ``output_activation()``, so the coherent layer sits at index
    ``-2`` of the returned ``nn.Sequential``.

    Args:
        sizes: layer widths ``[in, hidden..., out]``; at least two entries are
            required (``sizes[-2]`` is indexed).
        activation: module class instantiated after every hidden layer.
        beta, std_w_init, std_a: passed positionally to ``OnPolicyCoherentLinear``.
        output_activation: module class instantiated after the coherent layer.
    """
    layers = []
    # Only hidden layers are built here, so each one always uses ``activation``;
    # ``output_activation`` belongs to the coherent layer appended below.
    for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    layers += [
        OnPolicyCoherentLinear(sizes[-2], sizes[-1], beta, std_w_init, std_a),
        output_activation(),
    ]
    return nn.Sequential(*layers)


def noisy_mlp(sizes, activation, output_activation=nn.Identity):
    """Build an MLP in which every linear layer is a ``NoisyLinear``.

    Each ``NoisyLinear(sizes[i], sizes[i + 1])`` is followed by
    ``activation()``, except the last one, which is followed by
    ``output_activation()``.
    """
    pairs = list(zip(sizes[:-1], sizes[1:]))
    modules = []
    for position, (fan_in, fan_out) in enumerate(pairs, start=1):
        act_cls = output_activation if position == len(pairs) else activation
        modules.append(NoisyLinear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)


def our_noisy_mlp(sizes, activation, output_activation=nn.Identity):
    """Build an MLP whose final linear layer is an ``OurNoisyLinear``.

    Hidden layers are ordinary ``nn.Linear`` modules, each followed by
    ``activation()``. The last layer is ``OurNoisyLinear(sizes[-2], sizes[-1])``
    followed by ``output_activation()``, so it sits at index ``-2`` of the
    returned ``nn.Sequential``. At least two entries in ``sizes`` are required.
    """
    layers = []
    # Only hidden layers are built here, so each one always uses ``activation``;
    # ``output_activation`` belongs to the noisy layer appended below.
    for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
        layers += [nn.Linear(fan_in, fan_out), activation()]
    layers += [OurNoisyLinear(sizes[-2], sizes[-1]), output_activation()]
    return nn.Sequential(*layers)


def PSNE_mlp(sizes, activation, output_activation=nn.Identity):
    """Build an MLP in which every linear layer is a ``PSNELinear``.

    Each ``PSNELinear(sizes[i], sizes[i + 1])`` is followed by
    ``activation()``, except the last one, which is followed by
    ``output_activation()``.
    """
    n_linear = len(sizes) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        is_output = idx == n_linear - 1
        modules.append(PSNELinear(fan_in, fan_out))
        modules.append(output_activation() if is_output else activation())
    return nn.Sequential(*modules)


def count_vars(module):
    """Return the total number of scalar entries across ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total


def discount_cumsum(x, discount):
    """Compute discounted cumulative sums along axis 0 (the rllab trick).

    For input ``[x0, x1, x2]`` the result is::

        [x0 + discount * x1 + discount**2 * x2,
         x1 + discount * x2,
         x2]

    Implemented as a first-order IIR filter ``y[t] = x[t] + discount * y[t-1]``
    run over the reversed sequence, then reversed back.
    """
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, float(-discount)], reversed_x, axis=0)
    return filtered[::-1]


class Actor(nn.Module):
    """Base policy module.

    Subclasses implement ``_distribution(obs)`` (returning an action
    distribution) and ``_log_prob_from_distribution(pi, act)``.
    """

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Return ``(pi, logp_a)`` for ``obs``.

        ``pi`` is the action distribution; ``logp_a`` is the log-likelihood of
        ``act`` under ``pi`` when ``act`` is given, otherwise ``None``.
        """
        pi = self._distribution(obs)
        if act is None:
            return pi, None
        return pi, self._log_prob_from_distribution(pi, act)


class MLPCategoricalActor(Actor):
    """Categorical policy over ``act_dim`` discrete actions with MLP logits.

    The logits network is ``mlp([obs_dim, *hidden_sizes, act_dim], activation)``.
    Exploration comes only from sampling the action distribution, so the
    ``exploration`` attribute reports ``"action"``.
    """

    # Mirrors MLPGaussianActor.exploration. Code that dispatches on
    # ``pi.exploration`` (e.g. MLPActorCritic.step/reset) would otherwise raise
    # AttributeError for Discrete action spaces.
    exploration = "action"

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.logits_net = mlp([obs_dim] + list(hidden_sizes) + [act_dim], activation)

    def _distribution(self, obs):
        """Return a ``Categorical`` parameterised by ``logits_net(obs)``."""
        logits = self.logits_net(obs)
        return Categorical(logits=logits)

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of the integer actions ``act`` under ``pi``."""
        return pi.log_prob(act)


class MLPGaussianActor(Actor):
    """Diagonal-Gaussian policy whose mean is produced by an MLP.

    ``exploration`` selects how the mean network ``mu_net`` is built and
    where the standard deviation comes from:

    * ``"action"``: ``mlp`` mean network plus a learnable ``log_std``
      parameter, initialised to -0.5 for every action dimension.
    * ``"coherent"``: ``coherent_mlp`` mean network (also given ``beta``,
      ``std_w_init`` and ``std_a``); std is the constant ``std_a``.
    * ``"noisy"`` / ``"our_noisy"`` / ``"PSNE"``: ``noisy_mlp`` /
      ``our_noisy_mlp`` / ``PSNE_mlp`` mean network; std is the constant
      ``std_a``.

    Any other value raises ``TypeError``.
    """

    def __init__(
        self,
        obs_dim,
        act_dim,
        hidden_sizes,
        activation,
        exploration="action",
        std_a=0.1,
        beta=0.01,
        std_w_init=0.017,
    ):
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]

        if exploration == "action":
            self.mu_net = mlp(layer_sizes, activation)
            init_log_std = np.full(act_dim, -0.5, dtype=np.float32)
            self.log_std = torch.nn.Parameter(torch.as_tensor(init_log_std))
        else:
            # (name, mean-network builder) pairs, checked in order. Builders are
            # lambdas so only the selected network is actually constructed.
            variants = (
                (
                    "coherent",
                    lambda: coherent_mlp(
                        layer_sizes, activation, beta, std_w_init, std_a
                    ),
                ),
                ("noisy", lambda: noisy_mlp(layer_sizes, activation)),
                ("our_noisy", lambda: our_noisy_mlp(layer_sizes, activation)),
                ("PSNE", lambda: PSNE_mlp(layer_sizes, activation)),
            )
            for name, build in variants:
                if exploration == name:
                    self.mu_net = build()
                    break
            else:
                raise TypeError(
                    "Legit types of exploration: action, coherent, noisy, our_noisy, or PSNE!"
                )
            # These variants use a plain float std rather than a learned parameter.
            self.std_a = std_a

        self.exploration = exploration

    def _distribution(self, obs):
        """Return ``Normal(mu_net(obs), std)`` for the configured std source."""
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std) if self.exploration == "action" else self.std_a
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        """Return the joint log-density of ``act``, summed over the last axis."""
        # Torch's Normal.log_prob is element-wise; summing the last axis gives
        # the log-density of the diagonal Gaussian.
        per_dim = pi.log_prob(act)
        return per_dim.sum(dim=-1)


class MLPCritic(nn.Module):
    """State-value function: an MLP mapping ``obs_dim`` features to one scalar."""

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, 1]
        self.v_net = mlp(layer_sizes, activation)

    def forward(self, obs):
        """Return value estimates with the trailing size-1 output axis removed."""
        values = self.v_net(obs)
        # Dropping the last axis is critical so v lines up with per-step targets.
        return values.squeeze(-1)


class MLPActorCritic(nn.Module):
    """Actor-critic container holding a policy ``pi`` and a value critic ``v``.

    ``pi`` is an ``MLPGaussianActor`` for ``Box`` action spaces (configured by
    ``exploration``, ``std_a``, ``beta`` and ``std_w_init``) or an
    ``MLPCategoricalActor`` for ``Discrete`` action spaces, in which case the
    exploration arguments are not used. ``v`` is an ``MLPCritic``.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        hidden_sizes=(64, 64),
        activation=nn.Tanh,
        exploration="action",
        std_a=0.1,
        beta=0.01,
        std_w_init=0.017,
    ):
        super().__init__()

        # Only shape[0] is read; presumably a flat 1-D observation space.
        obs_dim = observation_space.shape[0]

        # policy builder depends on action space
        if isinstance(action_space, Box):
            self.pi = MLPGaussianActor(
                obs_dim,
                action_space.shape[0],
                hidden_sizes,
                activation,
                exploration,
                std_a,
                beta,
                std_w_init,
            )
        elif isinstance(action_space, Discrete):
            self.pi = MLPCategoricalActor(
                obs_dim, action_space.n, hidden_sizes, activation
            )
        else:
            # Fail fast rather than leaving ``self.pi`` unset, which would only
            # surface later as an AttributeError in step/reset.
            raise TypeError(
                "Unsupported action space "
                f"{type(action_space).__name__}: expected Box or Discrete."
            )

        # build value function
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def _exploration(self):
        """Return the policy's exploration scheme name.

        A policy without an ``exploration`` attribute (e.g. a categorical
        actor) only explores by sampling its distribution, so it is treated as
        ``"action"``.
        """
        return getattr(self.pi, "exploration", "action")

    def _noise_layers(self):
        """Return the policy layers whose noise state this class manages."""
        exploration = self._exploration()
        if exploration in ("coherent", "our_noisy"):
            # coherent_mlp / our_noisy_mlp append the special layer followed by
            # output_activation(), so it sits at index -2 of mu_net.
            return [self.pi.mu_net[-2]]
        if exploration == "noisy":
            return [m for m in self.pi.mu_net if isinstance(m, NoisyLinear)]
        if exploration == "PSNE":
            return [m for m in self.pi.mu_net if isinstance(m, PSNELinear)]
        return []

    def step(self, obs, test=False):
        """Select an action for ``obs`` without tracking gradients.

        Training (``test=False``): sample ``a ~ pi(obs)`` and return the numpy
        arrays ``(a, v(obs), logp_a)``. With ``"coherent"`` exploration the
        sampled action is first handed to the coherent layer via ``get_a`` and
        ``logp_a`` is computed under the distribution its
        ``get_marginal_pi()`` returns.

        Test (``test=True``): return only a deterministic action as a numpy
        array, namely the distribution mean for continuous policies or the
        most probable action for categorical ones (``Categorical.mean`` is NaN).

        ``.numpy()`` is called on the results, so tensors are assumed to live
        on the CPU.
        """
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            if test:
                if isinstance(pi, Categorical):
                    a = pi.probs.argmax(dim=-1)
                else:
                    a = pi.mean
                return a.numpy()

            a = pi.sample()
            if self._exploration() == "coherent":
                coherent_layer = self.pi.mu_net[-2]
                coherent_layer.get_a(a)
                marginal_pi = coherent_layer.get_marginal_pi()
                logp_a = self.pi._log_prob_from_distribution(marginal_pi, a)
            else:
                logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
            return a.numpy(), v.numpy(), logp_a.numpy()

    def act(self, obs):
        """Return only the sampled (training-mode) action from ``step``."""
        return self.step(obs)[0]

    def adapt_param_noise(self, kl_if_with_noise):
        """Forward ``kl_if_with_noise`` to every ``PSNELinear`` layer's ``adapt_std_w``.

        Does nothing unless the exploration scheme is ``"PSNE"``.
        """
        if self._exploration() != "PSNE":
            return
        for layer in self._noise_layers():
            layer.adapt_std_w(kl_if_with_noise)

    def set_pi_if_with_noise(self, if_with_noise):
        """Set ``if_with_noise`` on every noise-carrying policy layer (no-op for ``"action"``)."""
        for layer in self._noise_layers():
            layer.if_with_noise = if_with_noise

    def reset(self):
        """Call ``reset()`` on every noise-carrying policy layer (no-op for ``"action"``)."""
        for layer in self._noise_layers():
            layer.reset()
